//
//  GridSampleGradTest.cpp
//  MNNTests
//
//  Created by MNN on 2022/09/07.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include "../tools/train/source/grad/OpGrad.hpp"

using namespace MNN;
using namespace MNN::Express;

class GridSampleGradTest : public MNNTestCase {
public:
    // Label for this test case; run() interpolates it into every MNN_ERROR
    // failure message so mismatches can be traced back to this suite.
    char name[20]{"GridSample"};

    // Nothing to release explicitly: all state lives in members or locals of run().
    virtual ~GridSampleGradTest() = default;

    virtual bool run(int precision) {
        std::vector<int> shape = {2, 3, 2, 3};
        const int inputLen = shape[0] * shape[1] * shape[2] * shape[3];
        auto input = _Input(shape, NCHW);
        std::vector<float> inpuData = { 0.5500, 0.6721, 0.4343, 0.8518, 0.9456, 0.6444, 0.5927, 0.4439, 0.9329,
                                        0.1434, 0.6933, 0.0180, 0.3173, 0.2903, 0.4159, 0.8706, 0.1812, 0.5890,
                                        0.3834, 0.0335, 0.9997, 0.7504, 0.5379, 0.9836, 0.3202, 0.4824, 0.9982,
                                        0.8029, 0.2889, 0.8386, 0.2282, 0.6912, 0.2678, 0.9031, 0.7055, 0.9389};
        auto inputPtr = input->writeMap<float>();
        memcpy(inputPtr, inpuData.data(), inputLen * sizeof(float));

        std::vector<int> gridShape = {2, 6, 6, 2};
        const int gridLen = gridShape[0] * gridShape[1] * gridShape[2] * gridShape[3];
        auto grid = _Input(gridShape, NCHW);
        std::vector<float> gridData = { -1.0000, -1.0000, -0.6000, -1.0000, -0.2000, -1.0000,  0.2000, -1.0000,
                                        0.6000, -1.0000,  1.0000, -1.0000, -1.0000, -0.6000, -0.6000, -0.6000,
                                        -0.2000, -0.6000,  0.2000, -0.6000,  0.6000, -0.6000,  1.0000, -0.6000,
                                        -1.0000, -0.2000, -0.6000, -0.2000, -0.2000, -0.2000,  0.2000, -0.2000,
                                        0.6000, -0.2000,  1.0000, -0.2000, -1.0000,  0.2000, -0.6000,  0.2000,
                                        -0.2000,  0.2000,  0.2000,  0.2000,  0.6000,  0.2000,  1.0000,  0.2000,
                                        -1.0000,  0.6000, -0.6000,  0.6000, -0.2000,  0.6000,  0.2000,  0.6000,
                                        0.6000,  0.6000,  1.0000,  0.6000, -1.0000,  1.0000, -0.6000,  1.0000,
                                        -0.2000,  1.0000,  0.2000,  1.0000,  0.6000,  1.0000,  1.0000,  1.0000,
                                        -1.0000, -1.0000, -0.6000, -1.0000, -0.2000, -1.0000,  0.2000, -1.0000,
                                        0.6000, -1.0000,  1.0000, -1.0000, -1.0000, -0.6000, -0.6000, -0.6000,
                                        -0.2000, -0.6000,  0.2000, -0.6000,  0.6000, -0.6000,  1.0000, -0.6000,
                                        -1.0000, -0.2000, -0.6000, -0.2000, -0.2000, -0.2000,  0.2000, -0.2000,
                                        0.6000, -0.2000,  1.0000, -0.2000, -1.0000,  0.2000, -0.6000,  0.2000,
                                        -0.2000,  0.2000,  0.2000,  0.2000,  0.6000,  0.2000,  1.0000,  0.2000,
                                        -1.0000,  0.6000, -0.6000,  0.6000, -0.2000,  0.6000,  0.2000,  0.6000,
                                        0.6000,  0.6000,  1.0000,  0.6000, -1.0000,  1.0000, -0.6000,  1.0000,
                                        -0.2000,  1.0000,  0.2000,  1.0000,  0.6000,  1.0000,  1.0000,  1.0000};
        auto gridPtr = grid->writeMap<float>();
        memcpy(gridPtr, gridData.data(), gridLen * sizeof(float));
        

        // TODO: inference of this arguments combination is wrong
        auto mode = InterpolationMethod::NEAREST;
        auto paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_ZEROS;
        auto alignCorners = false;
        auto output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        auto outputNCHW = _Convert(output, NCHW);

        auto outputPtr = outputNCHW->readMap<float>();
        const int outputLen = shape[0] * shape[1] * gridShape[1] * gridShape[2];

        std::vector<float> outputTorch = {  0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721,
                                            0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343,
                                            0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456,
                                            0.9456, 0.6444, 0.6444, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                                            0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
                                            0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329,
                                            0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933,
                                            0.6933, 0.0180, 0.0180, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                                            0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903,
                                            0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159,
                                            0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812,
                                            0.1812, 0.5890, 0.5890, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                                            0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335,
                                            0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
                                            0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379,
                                            0.5379, 0.9836, 0.9836, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                                            0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824,
                                            0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982,
                                            0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
                                            0.2889, 0.8386, 0.8386, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                                            0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912,
                                            0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
                                            0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055,
                                            0.7055, 0.9389, 0.9389, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                // return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        auto opExpr = output->expr().first;
        auto grad = OpGrad::get(opExpr->get()->type());
        std::vector<float> outputDiff = {   0.4334, 0.0341, 0.2641, 0.7900, 0.5033, 0.5753, 0.0180, 0.7617, 0.3299,
                                            0.9388, 0.7645, 0.9430, 0.7312, 0.1058, 0.4154, 0.7537, 0.3959, 0.2058,
                                            0.9084, 0.9178, 0.6203, 0.3946, 0.5286, 0.0053, 0.5959, 0.4805, 0.3953,
                                            0.8146, 0.8543, 0.1426, 0.4022, 0.8199, 0.7822, 0.3160, 0.5057, 0.8435,
                                            0.5449, 0.0964, 0.4719, 0.3557, 0.1786, 0.8186, 0.0859, 0.2833, 0.5462,
                                            0.1870, 0.9203, 0.1523, 0.3556, 0.9206, 0.2185, 0.5502, 0.8321, 0.5941,
                                            0.3160, 0.0663, 0.8522, 0.8215, 0.3595, 0.2714, 0.6255, 0.9103, 0.2248,
                                            0.4765, 0.2330, 0.4213, 0.3474, 0.7129, 0.1307, 0.2414, 0.7421, 0.1453,
                                            0.5165, 0.7668, 0.1646, 0.0379, 0.1988, 0.1783, 0.3200, 0.5802, 0.7501,
                                            0.5057, 0.9157, 0.2080, 0.9982, 0.6694, 0.3964, 0.3710, 0.9381, 0.9157,
                                            0.2548, 0.2127, 0.8212, 0.5140, 0.3528, 0.2028, 0.9128, 0.3492, 0.9882,
                                            0.0330, 0.4107, 0.5150, 0.8750, 0.1118, 0.1271, 0.5068, 0.4232, 0.0709,
                                            0.2711, 0.0418, 0.5329, 0.8123, 0.0119, 0.1818, 0.2264, 0.6342, 0.1863,
                                            0.8303, 0.2253, 0.8572, 0.0520, 0.0255, 0.4094, 0.3164, 0.2758, 0.5764,
                                            0.7998, 0.7261, 0.9420, 0.8043, 0.7131, 0.3567, 0.1961, 0.3868, 0.4668,
                                            0.8830, 0.8475, 0.2829, 0.0681, 0.3372, 0.9303, 0.0397, 0.0962, 0.3651,
                                            0.1226, 0.7876, 0.4374, 0.1730, 0.2058, 0.7499, 0.8105, 0.5794, 0.7401,
                                            0.3478, 0.8476, 0.8795, 0.7856, 0.6042, 0.4180, 0.4664, 0.8128, 0.6839,
                                            0.9811, 0.7328, 0.9305, 0.1411, 0.4011, 0.4810, 0.5414, 0.6038, 0.1644,
                                            0.1686, 0.2125, 0.1554, 0.8285, 0.6496, 0.0667, 0.7326, 0.9510, 0.1087,
                                            0.4501, 0.8744, 0.3976, 0.8691, 0.7303, 0.2784, 0.4464, 0.8928, 0.6532,
                                            0.4175, 0.5971, 0.7475, 0.1091, 0.3149, 0.3717, 0.5579, 0.5649, 0.6624,
                                            0.8024, 0.1316, 0.8202, 0.7971, 0.6213, 0.9040, 0.9452, 0.9925, 0.4661,
                                            0.7995, 0.0764, 0.0370, 0.5322, 0.2354, 0.1298, 0.0324, 0.8321, 0.6498};
        auto inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        std::vector<float> expectedOutput = {   2.0842, 3.4919, 3.3878, 2.9026, 2.2248, 1.5308, 2.2867, 2.3295, 3.4960,
                                                1.9181, 2.3750, 1.2852, 3.8511, 2.2257, 3.3546, 1.7295, 2.3564, 1.4813,
                                                1.2510, 3.0876, 2.1284, 2.1088, 3.0961, 2.2002, 3.6899, 2.5827, 4.1795,
                                                2.8591, 1.4046, 1.2500, 3.0877, 3.2670, 3.5806, 2.8717, 2.8829, 1.6387};
        auto tmpgotOutput = _Convert(inputGrad[0], NCHW);
        auto gotOutput = tmpgotOutput->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                // return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::NEAREST;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_ZEROS;
        alignCorners = true;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721,
                        0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343,
                        0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456,
                        0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444,
                        0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
                        0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329,
                        0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933,
                        0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
                        0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903,
                        0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159,
                        0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812,
                        0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890,
                        0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335,
                        0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
                        0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379,
                        0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836,
                        0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824,
                        0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982,
                        0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
                        0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386,
                        0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912,
                        0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
                        0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055,
                        0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.0842, 3.4919, 3.3878, 4.1247, 3.3230, 2.8800, 2.2867, 2.3295, 3.4960,
                            2.9784, 2.7471, 2.1726, 3.8511, 2.2257, 3.3546, 2.7163, 2.9903, 1.9754,
                            1.2510, 3.0876, 2.1284, 2.5141, 4.0661, 2.6615, 3.6899, 2.5827, 4.1795,
                            4.3372, 2.2039, 2.3097, 3.0877, 3.2670, 3.5806, 3.6393, 3.0451, 3.1206};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::NEAREST;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_BORDER;
        alignCorners = false;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721,
                        0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343,
                        0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456,
                        0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444,
                        0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
                        0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329,
                        0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933,
                        0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
                        0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903,
                        0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159,
                        0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812,
                        0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890,
                        0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335,
                        0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
                        0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379,
                        0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836,
                        0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824,
                        0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982,
                        0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
                        0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386,
                        0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912,
                        0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
                        0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055,
                        0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.0842, 3.4919, 3.3878, 4.1247, 3.3230, 2.8800, 2.2867, 2.3295, 3.4960,
                            2.9784, 2.7471, 2.1726, 3.8511, 2.2257, 3.3546, 2.7163, 2.9903, 1.9754,
                            1.2510, 3.0876, 2.1284, 2.5141, 4.0661, 2.6615, 3.6899, 2.5827, 4.1795,
                            4.3372, 2.2039, 2.3097, 3.0877, 3.2670, 3.5806, 3.6393, 3.0451, 3.1206};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::NEAREST;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_BORDER;
        alignCorners = true;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721,
                        0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343,
                        0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456,
                        0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444,
                        0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
                        0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329,
                        0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933,
                        0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
                        0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903,
                        0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159,
                        0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812,
                        0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890,
                        0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335,
                        0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
                        0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379,
                        0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836,
                        0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824,
                        0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982,
                        0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
                        0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386,
                        0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912,
                        0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
                        0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055,
                        0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.0842, 3.4919, 3.3878, 4.1247, 3.3230, 2.8800, 2.2867, 2.3295, 3.4960,
                            2.9784, 2.7471, 2.1726, 3.8511, 2.2257, 3.3546, 2.7163, 2.9903, 1.9754,
                            1.2510, 3.0876, 2.1284, 2.5141, 4.0661, 2.6615, 3.6899, 2.5827, 4.1795,
                            4.3372, 2.2039, 2.3097, 3.0877, 3.2670, 3.5806, 3.6393, 3.0451, 3.1206};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::NEAREST;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_REFLECTION;
        alignCorners = false;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721,
                        0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343,
                        0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456,
                        0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444,
                        0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
                        0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329,
                        0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933,
                        0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
                        0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903,
                        0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159,
                        0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812,
                        0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890,
                        0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335,
                        0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
                        0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379,
                        0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836,
                        0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824,
                        0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982,
                        0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
                        0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386,
                        0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912,
                        0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
                        0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055,
                        0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.0842, 3.4919, 3.3878, 4.1247, 3.3230, 2.8800, 2.2867, 2.3295, 3.4960,
                            2.9784, 2.7471, 2.1726, 3.8511, 2.2257, 3.3546, 2.7163, 2.9903, 1.9754,
                            1.2510, 3.0876, 2.1284, 2.5141, 4.0661, 2.6615, 3.6899, 2.5827, 4.1795,
                            4.3372, 2.2039, 2.3097, 3.0877, 3.2670, 3.5806, 3.6393, 3.0451, 3.1206};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::NEAREST;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_REFLECTION;
        alignCorners = true;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721,
                        0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343,
                        0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456,
                        0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444,
                        0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
                        0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329,
                        0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933,
                        0.6933, 0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
                        0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903,
                        0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159, 0.4159,
                        0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812,
                        0.1812, 0.5890, 0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890,
                        0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335,
                        0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
                        0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379,
                        0.5379, 0.9836, 0.9836, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836,
                        0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824,
                        0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982,
                        0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
                        0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386,
                        0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912,
                        0.6912, 0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
                        0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055,
                        0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055, 0.9389, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.0842, 3.4919, 3.3878, 4.1247, 3.3230, 2.8800, 2.2867, 2.3295, 3.4960,
                            2.9784, 2.7471, 2.1726, 3.8511, 2.2257, 3.3546, 2.7163, 2.9903, 1.9754,
                            1.2510, 3.0876, 2.1284, 2.5141, 4.0661, 2.6615, 3.6899, 2.5827, 4.1795,
                            4.3372, 2.2039, 2.3097, 3.0877, 3.2670, 3.5806, 3.6393, 3.0451, 3.1206};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::BILINEAR;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_ZEROS;
        alignCorners = false;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.1375, 0.2811, 0.3177, 0.3004, 0.2290, 0.1086, 0.2475, 0.5060, 0.5719,
                        0.5407, 0.4123, 0.1954, 0.3203, 0.6519, 0.7201, 0.6771, 0.5230, 0.2487,
                        0.3806, 0.7715, 0.8329, 0.7789, 0.6096, 0.2907, 0.3833, 0.7751, 0.8257,
                        0.7697, 0.6071, 0.2900, 0.2130, 0.4306, 0.4587, 0.4276, 0.3373, 0.1611,
                        0.1482, 0.2889, 0.2443, 0.2953, 0.4420, 0.2332, 0.2667, 0.5200, 0.4397,
                        0.5315, 0.7956, 0.4198, 0.2290, 0.4640, 0.5005, 0.5606, 0.6445, 0.3292,
                        0.1391, 0.3122, 0.5164, 0.5207, 0.3251, 0.1462, 0.0645, 0.1786, 0.4755,
                        0.4416, 0.0770, 0.0081, 0.0358, 0.0992, 0.2642, 0.2454, 0.0428, 0.0045,
                        0.0793, 0.1573, 0.1492, 0.1640, 0.2017, 0.1040, 0.1428, 0.2831, 0.2686,
                        0.2952, 0.3630, 0.1872, 0.2416, 0.4607, 0.3253, 0.3206, 0.4468, 0.2339,
                        0.3523, 0.6555, 0.3611, 0.3109, 0.5048, 0.2685, 0.3918, 0.7215, 0.3492,
                        0.2732, 0.4934, 0.2650, 0.2176, 0.4008, 0.1940, 0.1518, 0.2741, 0.1472,
                        0.0958, 0.1742, 0.0692, 0.1617, 0.4515, 0.2499, 0.1725, 0.3136, 0.1246,
                        0.2910, 0.8128, 0.4499, 0.2467, 0.4626, 0.2774, 0.4278, 0.9139, 0.4974,
                        0.3202, 0.6149, 0.4627, 0.5671, 0.9282, 0.4942, 0.3377, 0.6562, 0.5415,
                        0.6044, 0.8451, 0.4426, 0.1876, 0.3646, 0.3008, 0.3358, 0.4695, 0.2459,
                        0.0800, 0.1682, 0.2169, 0.3186, 0.4733, 0.2495, 0.1441, 0.3028, 0.3904,
                        0.5734, 0.8520, 0.4492, 0.2325, 0.4609, 0.4365, 0.5821, 0.8977, 0.4752,
                        0.3290, 0.6270, 0.4403, 0.5088, 0.8325, 0.4432, 0.3613, 0.6763, 0.3988,
                        0.4084, 0.7053, 0.3774, 0.2007, 0.3758, 0.2215, 0.2269, 0.3918, 0.2096,
                        0.0571, 0.1372, 0.2762, 0.2821, 0.1551, 0.0670, 0.1027, 0.2470, 0.4971,
                        0.5078, 0.2791, 0.1205, 0.2153, 0.4572, 0.6160, 0.6276, 0.4918, 0.2346,
                        0.3503, 0.7007, 0.7010, 0.7121, 0.7339, 0.3688, 0.4064, 0.7950, 0.6883,
                        0.6980, 0.8240, 0.4225, 0.2258, 0.4417, 0.3824, 0.3878, 0.4578, 0.2347};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  1.7272, 2.1969, 2.4446, 2.5530, 2.1921, 2.0853, 1.5583, 1.7373, 2.2623,
                            2.1597, 1.8116, 1.7707, 2.3393, 1.8308, 2.2729, 2.0103, 1.8636, 1.6827,
                            1.3534, 1.9779, 1.7254, 1.8226, 2.4405, 2.1251, 2.6567, 1.8687, 2.5248,
                            2.6056, 1.4757, 1.6673, 2.2104, 2.2115, 2.5413, 2.2844, 2.0705, 2.0135};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::BILINEAR;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_ZEROS;
        alignCorners = true;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5988, 0.6477, 0.6245, 0.5294, 0.4343, 0.6104, 0.6569, 0.7035,
                        0.6767, 0.5765, 0.4763, 0.6707, 0.7150, 0.7593, 0.7289, 0.6236, 0.5183,
                        0.7311, 0.7731, 0.8152, 0.7810, 0.6707, 0.5604, 0.7914, 0.8312, 0.8710,
                        0.8332, 0.7178, 0.6024, 0.8518, 0.8893, 0.9268, 0.8854, 0.7649, 0.6444,
                        0.5927, 0.5332, 0.4737, 0.5417, 0.7373, 0.9329, 0.5028, 0.4992, 0.4956,
                        0.5450, 0.6475, 0.7499, 0.4130, 0.4653, 0.5175, 0.5483, 0.5576, 0.5669,
                        0.3231, 0.4313, 0.5395, 0.5516, 0.4678, 0.3840, 0.2333, 0.3973, 0.5614,
                        0.5549, 0.3780, 0.2010, 0.1434, 0.3634, 0.5833, 0.5582, 0.2881, 0.0180,
                        0.3173, 0.3065, 0.2957, 0.3154, 0.3657, 0.4159, 0.4280, 0.3642, 0.3004,
                        0.3049, 0.3777, 0.4505, 0.5386, 0.4218, 0.3051, 0.2944, 0.3897, 0.4851,
                        0.6493, 0.4795, 0.3097, 0.2838, 0.4018, 0.5198, 0.7599, 0.5372, 0.3144,
                        0.2733, 0.4138, 0.5544, 0.8706, 0.5948, 0.3191, 0.2628, 0.4259, 0.5890,
                        0.3834, 0.2434, 0.1035, 0.2267, 0.6132, 0.9997, 0.4568, 0.3278, 0.1989,
                        0.3068, 0.6516, 0.9965, 0.5302, 0.4122, 0.2942, 0.3869, 0.6901, 0.9933,
                        0.6036, 0.4966, 0.3896, 0.4669, 0.7285, 0.9900, 0.6770, 0.5810, 0.4850,
                        0.5470, 0.7669, 0.9868, 0.7504, 0.6654, 0.5804, 0.6270, 0.8053, 0.9836,
                        0.3202, 0.3851, 0.4500, 0.5856, 0.7919, 0.9982, 0.4167, 0.4275, 0.4383,
                        0.5482, 0.7572, 0.9663, 0.5133, 0.4700, 0.4267, 0.5109, 0.7226, 0.9344,
                        0.6098, 0.5124, 0.4150, 0.4735, 0.6880, 0.9024, 0.7064, 0.5549, 0.4034,
                        0.4362, 0.6534, 0.8705, 0.8029, 0.5973, 0.3917, 0.3988, 0.6187, 0.8386,
                        0.2282, 0.4134, 0.5986, 0.6065, 0.4372, 0.2678, 0.3632, 0.4955, 0.6279,
                        0.6357, 0.5188, 0.4020, 0.4982, 0.5777, 0.6572, 0.6648, 0.6005, 0.5362,
                        0.6331, 0.6598, 0.6865, 0.6939, 0.6822, 0.6705, 0.7681, 0.7419, 0.7157,
                        0.7230, 0.7639, 0.8047, 0.9031, 0.8241, 0.7450, 0.7522, 0.8455, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.2920, 3.8967, 2.9875, 3.2304, 4.2240, 2.6610, 2.0140, 3.2235, 2.7043,
                            2.5440, 3.3400, 2.1845, 3.0341, 3.3267, 2.2842, 3.1067, 3.2178, 2.1438,
                            1.6815, 3.5478, 2.3330, 1.9165, 3.9037, 2.3262, 3.2967, 3.4461, 3.1963,
                            3.6989, 3.3385, 2.3265, 2.9128, 4.4006, 3.1326, 3.0053, 3.3945, 2.8945};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::BILINEAR;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_BORDER;
        alignCorners = false;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5622, 0.6355, 0.6008, 0.4581, 0.4343, 0.5500, 0.5622, 0.6355,
                        0.6008, 0.4581, 0.4343, 0.6405, 0.6519, 0.7201, 0.6771, 0.5230, 0.4973,
                        0.7613, 0.7715, 0.8329, 0.7789, 0.6096, 0.5814, 0.8518, 0.8612, 0.9175,
                        0.8552, 0.6745, 0.6444, 0.8518, 0.8612, 0.9175, 0.8552, 0.6745, 0.6444,
                        0.5927, 0.5778, 0.4885, 0.5906, 0.8840, 0.9329, 0.5927, 0.5778, 0.4885,
                        0.5906, 0.8840, 0.9329, 0.4579, 0.4640, 0.5005, 0.5606, 0.6445, 0.6584,
                        0.2782, 0.3122, 0.5164, 0.5207, 0.3251, 0.2925, 0.1434, 0.1984, 0.5283,
                        0.4907, 0.0855, 0.0180, 0.1434, 0.1984, 0.5283, 0.4907, 0.0855, 0.0180,
                        0.3173, 0.3146, 0.2984, 0.3280, 0.4033, 0.4159, 0.3173, 0.3146, 0.2984,
                        0.3280, 0.4033, 0.4159, 0.4833, 0.4607, 0.3253, 0.3206, 0.4468, 0.4678,
                        0.7046, 0.6555, 0.3611, 0.3109, 0.5048, 0.5371, 0.8706, 0.8017, 0.3880,
                        0.3035, 0.5482, 0.5890, 0.8706, 0.8017, 0.3880, 0.3035, 0.5482, 0.5890,
                        0.3834, 0.3484, 0.1385, 0.3234, 0.9031, 0.9997, 0.3834, 0.3484, 0.1385,
                        0.3234, 0.9031, 0.9997, 0.4935, 0.4626, 0.2774, 0.4278, 0.9139, 0.9949,
                        0.6403, 0.6149, 0.4627, 0.5671, 0.9282, 0.9884, 0.7504, 0.7291, 0.6016,
                        0.6716, 0.9390, 0.9836, 0.7504, 0.7291, 0.6016, 0.6716, 0.9390, 0.9836,
                        0.3202, 0.3364, 0.4337, 0.6371, 0.9466, 0.9982, 0.3202, 0.3364, 0.4337,
                        0.6371, 0.9466, 0.9982, 0.4650, 0.4609, 0.4365, 0.5821, 0.8977, 0.9503,
                        0.6581, 0.6270, 0.4403, 0.5088, 0.8325, 0.8865, 0.8029, 0.7515, 0.4431,
                        0.4538, 0.7836, 0.8386, 0.8029, 0.7515, 0.4431, 0.4538, 0.7836, 0.8386,
                        0.2282, 0.2745, 0.5523, 0.5642, 0.3101, 0.2678, 0.2282, 0.2745, 0.5523,
                        0.5642, 0.3101, 0.2678, 0.4307, 0.4572, 0.6160, 0.6276, 0.4918, 0.4691,
                        0.7006, 0.7007, 0.7010, 0.7121, 0.7339, 0.7376, 0.9031, 0.8833, 0.7648,
                        0.7755, 0.9156, 0.9389, 0.9031, 0.8833, 0.7648, 0.7755, 0.9156, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.5877, 2.6968, 3.9095, 4.1514, 2.7408, 3.2053, 2.3421, 2.1040, 3.4308,
                            3.3573, 2.0751, 2.7010, 3.7347, 2.0529, 3.0647, 3.5380, 2.1913, 2.5319,
                            1.9809, 2.5312, 2.7609, 2.6094, 2.9085, 2.9178, 4.1109, 2.2224, 4.0877,
                            4.3476, 1.8670, 2.6672, 3.5051, 2.8249, 4.0540, 3.7293, 2.2799, 3.3471};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::BILINEAR;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_BORDER;
        alignCorners = true;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5988, 0.6477, 0.6245, 0.5294, 0.4343, 0.6104, 0.6569, 0.7035,
                        0.6767, 0.5765, 0.4763, 0.6707, 0.7150, 0.7593, 0.7289, 0.6236, 0.5183,
                        0.7311, 0.7731, 0.8152, 0.7810, 0.6707, 0.5604, 0.7914, 0.8312, 0.8710,
                        0.8332, 0.7178, 0.6024, 0.8518, 0.8893, 0.9268, 0.8854, 0.7649, 0.6444,
                        0.5927, 0.5332, 0.4737, 0.5417, 0.7373, 0.9329, 0.5028, 0.4992, 0.4956,
                        0.5450, 0.6475, 0.7499, 0.4130, 0.4653, 0.5175, 0.5483, 0.5576, 0.5669,
                        0.3231, 0.4313, 0.5395, 0.5516, 0.4678, 0.3840, 0.2333, 0.3973, 0.5614,
                        0.5549, 0.3780, 0.2010, 0.1434, 0.3634, 0.5833, 0.5582, 0.2881, 0.0180,
                        0.3173, 0.3065, 0.2957, 0.3154, 0.3657, 0.4159, 0.4280, 0.3642, 0.3004,
                        0.3049, 0.3777, 0.4505, 0.5386, 0.4218, 0.3051, 0.2944, 0.3897, 0.4851,
                        0.6493, 0.4795, 0.3097, 0.2838, 0.4018, 0.5198, 0.7599, 0.5372, 0.3144,
                        0.2733, 0.4138, 0.5544, 0.8706, 0.5948, 0.3191, 0.2628, 0.4259, 0.5890,
                        0.3834, 0.2434, 0.1035, 0.2267, 0.6132, 0.9997, 0.4568, 0.3278, 0.1989,
                        0.3068, 0.6516, 0.9965, 0.5302, 0.4122, 0.2942, 0.3869, 0.6901, 0.9933,
                        0.6036, 0.4966, 0.3896, 0.4669, 0.7285, 0.9900, 0.6770, 0.5810, 0.4850,
                        0.5470, 0.7669, 0.9868, 0.7504, 0.6654, 0.5804, 0.6270, 0.8053, 0.9836,
                        0.3202, 0.3851, 0.4500, 0.5856, 0.7919, 0.9982, 0.4167, 0.4275, 0.4383,
                        0.5482, 0.7572, 0.9663, 0.5133, 0.4700, 0.4267, 0.5109, 0.7226, 0.9344,
                        0.6098, 0.5124, 0.4150, 0.4735, 0.6880, 0.9024, 0.7064, 0.5549, 0.4034,
                        0.4362, 0.6534, 0.8705, 0.8029, 0.5973, 0.3917, 0.3988, 0.6187, 0.8386,
                        0.2282, 0.4134, 0.5986, 0.6065, 0.4372, 0.2678, 0.3632, 0.4955, 0.6279,
                        0.6357, 0.5188, 0.4020, 0.4982, 0.5777, 0.6572, 0.6648, 0.6005, 0.5362,
                        0.6331, 0.6598, 0.6865, 0.6939, 0.6822, 0.6705, 0.7681, 0.7419, 0.7157,
                        0.7230, 0.7639, 0.8047, 0.9031, 0.8241, 0.7450, 0.7522, 0.8455, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.2920, 3.8967, 2.9875, 3.2304, 4.2240, 2.6610, 2.0140, 3.2235, 2.7043,
                            2.5440, 3.3400, 2.1845, 3.0341, 3.3267, 2.2842, 3.1067, 3.2178, 2.1438,
                            1.6815, 3.5478, 2.3330, 1.9165, 3.9037, 2.3262, 3.2967, 3.4461, 3.1963,
                            3.6989, 3.3385, 2.3265, 2.9128, 4.4006, 3.1326, 3.0053, 3.3945, 2.8945};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::BILINEAR;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_REFLECTION;
        alignCorners = false;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5622, 0.6355, 0.6008, 0.4581, 0.4343, 0.5500, 0.5622, 0.6355,
                        0.6008, 0.4581, 0.4343, 0.6405, 0.6519, 0.7201, 0.6771, 0.5230, 0.4973,
                        0.7613, 0.7715, 0.8329, 0.7789, 0.6096, 0.5814, 0.8518, 0.8612, 0.9175,
                        0.8552, 0.6745, 0.6444, 0.8518, 0.8612, 0.9175, 0.8552, 0.6745, 0.6444,
                        0.5927, 0.5778, 0.4885, 0.5906, 0.8840, 0.9329, 0.5927, 0.5778, 0.4885,
                        0.5906, 0.8840, 0.9329, 0.4579, 0.4640, 0.5005, 0.5606, 0.6445, 0.6584,
                        0.2782, 0.3122, 0.5164, 0.5207, 0.3251, 0.2925, 0.1434, 0.1984, 0.5283,
                        0.4907, 0.0855, 0.0180, 0.1434, 0.1984, 0.5283, 0.4907, 0.0855, 0.0180,
                        0.3173, 0.3146, 0.2984, 0.3280, 0.4033, 0.4159, 0.3173, 0.3146, 0.2984,
                        0.3280, 0.4033, 0.4159, 0.4833, 0.4607, 0.3253, 0.3206, 0.4468, 0.4678,
                        0.7046, 0.6555, 0.3611, 0.3109, 0.5048, 0.5371, 0.8706, 0.8017, 0.3880,
                        0.3035, 0.5482, 0.5890, 0.8706, 0.8017, 0.3880, 0.3035, 0.5482, 0.5890,
                        0.3834, 0.3484, 0.1385, 0.3234, 0.9031, 0.9997, 0.3834, 0.3484, 0.1385,
                        0.3234, 0.9031, 0.9997, 0.4935, 0.4626, 0.2774, 0.4278, 0.9139, 0.9949,
                        0.6403, 0.6149, 0.4627, 0.5671, 0.9282, 0.9884, 0.7504, 0.7291, 0.6016,
                        0.6716, 0.9390, 0.9836, 0.7504, 0.7291, 0.6016, 0.6716, 0.9390, 0.9836,
                        0.3202, 0.3364, 0.4337, 0.6371, 0.9466, 0.9982, 0.3202, 0.3364, 0.4337,
                        0.6371, 0.9466, 0.9982, 0.4650, 0.4609, 0.4365, 0.5821, 0.8977, 0.9503,
                        0.6581, 0.6270, 0.4403, 0.5088, 0.8325, 0.8865, 0.8029, 0.7515, 0.4431,
                        0.4538, 0.7836, 0.8386, 0.8029, 0.7515, 0.4431, 0.4538, 0.7836, 0.8386,
                        0.2282, 0.2745, 0.5523, 0.5642, 0.3101, 0.2678, 0.2282, 0.2745, 0.5523,
                        0.5642, 0.3101, 0.2678, 0.4307, 0.4572, 0.6160, 0.6276, 0.4918, 0.4691,
                        0.7006, 0.7007, 0.7010, 0.7121, 0.7339, 0.7376, 0.9031, 0.8833, 0.7648,
                        0.7755, 0.9156, 0.9389, 0.9031, 0.8833, 0.7648, 0.7755, 0.9156, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.5877, 2.6968, 3.9095, 4.1514, 2.7408, 3.2053, 2.3421, 2.1040, 3.4308,
                            3.3573, 2.0751, 2.7010, 3.7347, 2.0529, 3.0647, 3.5380, 2.1913, 2.5319,
                            1.9809, 2.5312, 2.7609, 2.6094, 2.9085, 2.9178, 4.1109, 2.2224, 4.0877,
                            4.3476, 1.8670, 2.6672, 3.5051, 2.8249, 4.0540, 3.7293, 2.2799, 3.3471};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        mode = InterpolationMethod::BILINEAR;
        paddingMode = GridSamplePaddingMode::GRID_SAMPLE_PADDING_REFLECTION;
        alignCorners = true;
        output = _GridSample(_Convert(input, NC4HW4), grid, mode, paddingMode, alignCorners);
        outputNCHW = _Convert(output, NCHW);

        outputPtr = outputNCHW->readMap<float>();

        outputTorch = { 0.5500, 0.5988, 0.6477, 0.6245, 0.5294, 0.4343, 0.6104, 0.6569, 0.7035,
                        0.6767, 0.5765, 0.4763, 0.6707, 0.7150, 0.7593, 0.7289, 0.6236, 0.5183,
                        0.7311, 0.7731, 0.8152, 0.7810, 0.6707, 0.5604, 0.7914, 0.8312, 0.8710,
                        0.8332, 0.7178, 0.6024, 0.8518, 0.8893, 0.9268, 0.8854, 0.7649, 0.6444,
                        0.5927, 0.5332, 0.4737, 0.5417, 0.7373, 0.9329, 0.5028, 0.4992, 0.4956,
                        0.5450, 0.6475, 0.7499, 0.4130, 0.4653, 0.5175, 0.5483, 0.5576, 0.5669,
                        0.3231, 0.4313, 0.5395, 0.5516, 0.4678, 0.3840, 0.2333, 0.3973, 0.5614,
                        0.5549, 0.3780, 0.2010, 0.1434, 0.3634, 0.5833, 0.5582, 0.2881, 0.0180,
                        0.3173, 0.3065, 0.2957, 0.3154, 0.3657, 0.4159, 0.4280, 0.3642, 0.3004,
                        0.3049, 0.3777, 0.4505, 0.5386, 0.4218, 0.3051, 0.2944, 0.3897, 0.4851,
                        0.6493, 0.4795, 0.3097, 0.2838, 0.4018, 0.5198, 0.7599, 0.5372, 0.3144,
                        0.2733, 0.4138, 0.5544, 0.8706, 0.5948, 0.3191, 0.2628, 0.4259, 0.5890,
                        0.3834, 0.2434, 0.1035, 0.2267, 0.6132, 0.9997, 0.4568, 0.3278, 0.1989,
                        0.3068, 0.6516, 0.9965, 0.5302, 0.4122, 0.2942, 0.3869, 0.6901, 0.9933,
                        0.6036, 0.4966, 0.3896, 0.4669, 0.7285, 0.9900, 0.6770, 0.5810, 0.4850,
                        0.5470, 0.7669, 0.9868, 0.7504, 0.6654, 0.5804, 0.6270, 0.8053, 0.9836,
                        0.3202, 0.3851, 0.4500, 0.5856, 0.7919, 0.9982, 0.4167, 0.4275, 0.4383,
                        0.5482, 0.7572, 0.9663, 0.5133, 0.4700, 0.4267, 0.5109, 0.7226, 0.9344,
                        0.6098, 0.5124, 0.4150, 0.4735, 0.6880, 0.9024, 0.7064, 0.5549, 0.4034,
                        0.4362, 0.6534, 0.8705, 0.8029, 0.5973, 0.3917, 0.3988, 0.6187, 0.8386,
                        0.2282, 0.4134, 0.5986, 0.6065, 0.4372, 0.2678, 0.3632, 0.4955, 0.6279,
                        0.6357, 0.5188, 0.4020, 0.4982, 0.5777, 0.6572, 0.6648, 0.6005, 0.5362,
                        0.6331, 0.6598, 0.6865, 0.6939, 0.6822, 0.6705, 0.7681, 0.7419, 0.7157,
                        0.7230, 0.7639, 0.8047, 0.9031, 0.8241, 0.7450, 0.7522, 0.8455, 0.9389};

        for (int i = 0, count = 0; i < outputLen; ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // printf("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        opExpr = output->expr().first;
        grad = OpGrad::get(opExpr->get()->type());
        inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {shape[0], shape[1], gridShape[1], gridShape[2]}, NCHW), NC4HW4)});

        expectedOutput = {  2.2920, 3.8967, 2.9875, 3.2304, 4.2240, 2.6610, 2.0140, 3.2235, 2.7043,
                            2.5440, 3.3400, 2.1845, 3.0341, 3.3267, 2.2842, 3.1067, 3.2178, 2.1438,
                            1.6815, 3.5478, 2.3330, 1.9165, 3.9037, 2.3262, 3.2967, 3.4461, 3.1963,
                            3.6989, 3.3385, 2.3265, 2.9128, 4.4006, 3.1326, 3.0053, 3.3945, 2.8945};
        gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < inputLen; ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.0001) {
                MNN_ERROR("%d: %s grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // printf("\tgrad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
            }
        }


        return true;
    }
};

// Register GridSampleGradTest with the MNN test suite (macro from MNNTestSuite.h)
// under the name "grad/grid_sample". Presumably the test runner uses this name
// to select the case; confirm against MNNTestSuite.h.
MNNTestSuiteRegister(GridSampleGradTest, "grad/grid_sample");
